//WriteBack.scala
package mycore

import chisel3._
import chisel3.util._

/** Write-back (WB) pipeline stage.
  *
  * Latches the MEM/WB pipeline register (`io.memwb`) whenever `io.ENABLE` is high and
  * drives the register-file write port (`io.wbdata`). It also produces the difftest
  * commit / architectural-state / trap events.
  *
  * Difftest events are delayed by one *enabled* cycle. The WB-stage fields are captured
  * into a second set of `RegEnable` registers and reported (gated by `io.ENABLE`) the
  * next time the pipeline advances.
  */
class WriteBack extends Module{
  val io = IO(new Bundle{

    val ENABLE = Input(Bool())
    val memwb = Flipped(new MEMWB)
    val wbdata = new WBDATA
    val gpr = Input(Vec(32, UInt(64.W)))

    // difftest
    val InstrCommit = Flipped(new DiffInstrCommitIO)
    val ArchIntRegState = Flipped(new DiffArchIntRegStateIO)
    val CSRState = Flipped(new DiffCSRStateIO)
    val TrapEvent = Flipped(new DiffTrapEventIO)
    val ArchFpRegState = Flipped(new DiffArchFpRegStateIO)
    val ArchEvent = Flipped(new DiffArchEventIO)

    // When high, the simulation-only putch printf below is suppressed
    val isSoC = Input(Bool())

  })

  // Signals used by this pipeline stage: the MEM/WB pipeline register, updated only
  // while ENABLE is high and reset to all zeros.
  val wbs = RegEnable(io.memwb, WireInit(0.U.asTypeOf(new MEMWB())), io.ENABLE)
  val Valid = wbs.Valid
  val Inst = wbs.Inst
  val PC = wbs.PC
  val isa = wbs.isa
  val src1 = wbs.src1
  val src2 = wbs.src2
  val imm = wbs.imm
  val wen = wbs.wen
  val regdes = wbs.regdes
  val aluresult = wbs.aluresult
  val branch = wbs.branch
  val target = wbs.target
  val link = wbs.link
  val auipc = wbs.auipc
  val loadData = wbs.loadData
  val csrData = wbs.csrData
  val clintValid = wbs.clintValid
  val TimerInterrupt = wbs.TimerInterrupt
  val EnvironmentCall = wbs.EnvironmentCall
  val csr = wbs.csr
  /*------------------------------*/
  /*          Stage body          */
  /*------------------------------*/
  // Free-running cycle counter, and the number of valid instructions seen in enabled
  // cycles. Both are reported through TrapEvent.
  val cycleCnt = RegInit(0.U(64.W))
  val instrCnt = RegInit(0.U(64.W))
  cycleCnt := cycleCnt + 1.U
  when(io.ENABLE && Valid) { instrCnt := instrCnt + 1.U }

  // The result sources are OR-combined rather than muxed. This is only correct if every
  // source not selected by the instruction is zero.
  // NOTE(review): confirm upstream stages zero the unused result fields.
  val wbdata = aluresult | link | auipc | loadData | csrData

  io.wbdata.wen := Valid && wen
  io.wbdata.regdes := regdes
  io.wbdata.data := wbdata

  /*------------------------------*/

  // Signals passed to the next pipeline stage: none.

  // difftest
  /*--------------------------Putch---------------------------*/
  // Custom opcodes: 0x7b prints a0 (x10) as a character, and 0x6b halts the simulation.
  val isPrint = Inst(6,0) === "h7b".U(7.W)
  val isHalt = Inst(6,0) === "h6b".U(7.W)

  when(~io.isSoC){
    when(io.ENABLE && Valid && isPrint){
      printf("%c", io.gpr(10))
    }
  }

  /*--------------------------Difftest---------------------------*/
  val isMMIO = clintValid
  val isCSR = isa.CSRRW || isa.CSRRS || isa.CSRRC || isa.CSRRWI || isa.CSRRSI || isa.CSRRCI
  // CSR reads the reference model cannot reproduce cycle-exactly:
  // mcycle (0xB00) and mip (0x344).
  val vis_mcycle  = isCSR && (Inst(31,20) === "hb00".U(12.W))
  val vis_mip     = isCSR && (Inst(31,20) === "h344".U(12.W))

  // WB-stage state delayed by one enabled cycle. Each value is registered exactly once
  // and shared by every difftest port that needs it, instead of duplicating registers.
  val TimerInterruptReg = RegEnable(TimerInterrupt, false.B, io.ENABLE)
  val commitValid = RegEnable(Valid, false.B, io.ENABLE)
  val commitPC = RegEnable(PC, 0.U, io.ENABLE)
  val commitInst = RegEnable(Inst, 0.U, io.ENABLE)
  val DifftestValid = io.ENABLE && (commitValid || TimerInterruptReg)

  // A cycle that reports a timer interrupt is not reported as an instruction commit.
  io.InstrCommit.clock := clock
  io.InstrCommit.coreid := 0.U
  io.InstrCommit.index := 0.U
  io.InstrCommit.valid := DifftestValid && ~TimerInterruptReg
  io.InstrCommit.pc := commitPC
  io.InstrCommit.instr := commitInst
  // Skip comparison for putch, non-reproducible CSR reads, and CLINT MMIO accesses.
  io.InstrCommit.skip := DifftestValid && RegEnable(isPrint || vis_mcycle || isMMIO || vis_mip, false.B, io.ENABLE)
  io.InstrCommit.isRVC := false.B
  io.InstrCommit.scFailed := false.B
  io.InstrCommit.wen := RegEnable(wen, false.B, io.ENABLE)
  io.InstrCommit.wdata := RegEnable(wbdata, 0.U, io.ENABLE)
  io.InstrCommit.wdest := RegEnable(regdes, 0.U, io.ENABLE)

  io.ArchIntRegState.clock := clock
  io.ArchIntRegState.coreid := 0.U
  io.ArchIntRegState.gpr := io.gpr

  // Only M-mode CSRs are implemented. S-mode CSRs are reported as zero, except sstatus.
  // sstatus is shown as a masked view of mstatus.
  // NOTE(review): the mask is assumed to match the reference model's sstatus read mask.
  io.CSRState.clock := clock
  io.CSRState.coreid := 0.U
  io.CSRState.mstatus := csr.mstatus
  io.CSRState.mcause := csr.mcause
  io.CSRState.mepc := csr.mepc
  io.CSRState.sstatus := csr.mstatus & "h8000_0003_000d_e122".U
  io.CSRState.scause := 0.U
  io.CSRState.sepc := 0.U
  io.CSRState.satp := 0.U
  io.CSRState.mip := 0.U
  io.CSRState.mie := csr.mie
  io.CSRState.mscratch := csr.mscratch
  io.CSRState.sscratch := 0.U
  io.CSRState.mideleg := 0.U
  io.CSRState.medeleg := csr.medeleg
  io.CSRState.mtval := 0.U
  io.CSRState.stval := 0.U
  io.CSRState.mtvec := csr.mtvec
  io.CSRState.stvec := 0.U
  // Privilege mode is hard-wired to 3 (machine mode).
  io.CSRState.priviledgeMode := 3.U

  // Interrupt number 7 is the RISC-V machine timer interrupt cause code.
  io.ArchEvent.clock := clock
  io.ArchEvent.coreid := 0.U
  io.ArchEvent.intrNO := Mux(DifftestValid && TimerInterruptReg, 7.U, 0.U)
  io.ArchEvent.cause := 0.U
  io.ArchEvent.exceptionPC := commitPC
  io.ArchEvent.exceptionInst := commitInst

  // The halt instruction ends the simulation; the trap code is taken from a0 (x10).
  io.TrapEvent.clock := clock
  io.TrapEvent.coreid := 0.U
  io.TrapEvent.valid := DifftestValid && RegEnable(isHalt, false.B, io.ENABLE)
  io.TrapEvent.code := io.gpr(10)(7,0)
  io.TrapEvent.pc := commitPC
  io.TrapEvent.cycleCnt := cycleCnt
  io.TrapEvent.instrCnt := instrCnt

  // There is no floating-point register file. The previous code used a RegInit that was
  // never written, which is always zero after reset, so drive the constant directly.
  io.ArchFpRegState.clock := clock
  io.ArchFpRegState.coreid := 0.U
  io.ArchFpRegState.fpr := VecInit(Seq.fill(32)(0.U(64.W)))

}